Skip to content

Commit bb0530f

Browse files
eonofreyfacebook-github-bot
authored andcommitted
Consolidate ax/adapter tests (#5006)
Summary: Part of a 19-diff stack to consolidate repetitive tests across Ax and PTS using `subTest`. Consolidate 8 test files in ax/adapter/ and ax/adapter/transforms/ — adds subTest parameterization to torch adapter, trial-as-task transform, logit transform, and objective-as-constraint tests. Differential Revision: D95603401
1 parent 1054802 commit bb0530f

8 files changed

Lines changed: 145 additions & 212 deletions

ax/adapter/tests/test_adapter_utils.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -313,33 +313,16 @@ def get_adapter(min: float, max: float) -> TorchAdapter:
313313

314314
def test_arm_to_np_array(self) -> None:
315315
# Test extracting target point from arm with valid parameters
316-
317-
# Setup: create arm with target parameter values
318316
target_arm = Arm(parameters={"x1": 0.5, "x2": 1.5, "x3": 2.5})
319-
parameters = ["x1", "x2", "x3"]
320-
321-
# Execute: extract target point
322-
actual = arm_to_np_array(arm=target_arm, parameters=parameters)
323-
324-
# Assert: confirm extracted values match expected order
325-
expected = np.array([0.5, 1.5, 2.5])
326-
self.assertIsNotNone(actual)
327-
np.testing.assert_array_equal(actual, expected)
328-
329-
def test_extract_arm_to_np_array_different_parameter_order(self) -> None:
330-
# Test extracting target point with different parameter ordering
331-
332-
# Setup: create arm and specify parameters in different order
333-
target_arm = Arm(parameters={"x1": 0.5, "x2": 1.5, "x3": 2.5})
334-
parameters = ["x3", "x1", "x2"]
335-
336-
# Execute: extract target point
337-
actual = arm_to_np_array(arm=target_arm, parameters=parameters)
338-
339-
# Assert: confirm values are extracted in specified parameter order
340-
expected = np.array([2.5, 0.5, 1.5])
341-
self.assertIsNotNone(actual)
342-
np.testing.assert_array_equal(actual, expected)
317+
cases = [
318+
(["x1", "x2", "x3"], np.array([0.5, 1.5, 2.5])),
319+
(["x3", "x1", "x2"], np.array([2.5, 0.5, 1.5])),
320+
]
321+
for parameters, expected in cases:
322+
with self.subTest(parameters=parameters):
323+
actual = arm_to_np_array(arm=target_arm, parameters=parameters)
324+
self.assertIsNotNone(actual)
325+
np.testing.assert_array_equal(actual, expected)
343326

344327
def test_arm_to_np_array_none(self) -> None:
345328
# Test that None is returned when target_arm is None

ax/adapter/tests/test_base_adapter.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,10 @@ def _test_init_with_data(self, multi_objective: bool) -> None:
193193
search_space=search_space, experiment_data=experiment_data
194194
)
195195

196-
def test_init_with_data_single_objective(self) -> None:
197-
self._test_init_with_data(multi_objective=False)
198-
199-
def test_init_with_data_multi_objective(self) -> None:
200-
self._test_init_with_data(multi_objective=True)
196+
def test_init_with_data(self) -> None:
197+
for multi_objective in (False, True):
198+
with self.subTest(multi_objective=multi_objective):
199+
self._test_init_with_data(multi_objective=multi_objective)
201200

202201
def test_fit_tracking_metrics(self) -> None:
203202
# Test error when fit_tracking_metrics is False and optimization

ax/adapter/tests/test_hierarchical_search_space.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -220,20 +220,16 @@ def _base_test_predict_and_cv(
220220
cv_res = cross_validate(adapter=mbm)
221221
self.assertEqual(len(cv_res), len(experiment.trials))
222222

223-
def test_with_non_hierarchical_hss(self) -> None:
224-
experiment = self._test_gen_base(
225-
hss=self.non_hierarchical_hss, expected_num_candidate_params=[3]
226-
)
227-
self._base_test_predict_and_cv(experiment=experiment)
228-
229-
def test_with_simple_hss(self) -> None:
230-
experiment = self._test_gen_base(
231-
hss=self.simple_hss, expected_num_candidate_params=[2]
232-
)
233-
self._base_test_predict_and_cv(experiment=experiment)
234-
235-
def test_with_complex_hss(self) -> None:
236-
experiment = self._test_gen_base(
237-
hss=self.complex_hss, expected_num_candidate_params=[2, 4, 5]
238-
)
239-
self._base_test_predict_and_cv(experiment=experiment)
223+
def test_with_hss_variants(self) -> None:
224+
cases = [
225+
("non_hierarchical", self.non_hierarchical_hss, [3]),
226+
("simple", self.simple_hss, [2]),
227+
("complex", self.complex_hss, [2, 4, 5]),
228+
]
229+
for label, hss, expected_num_candidate_params in cases:
230+
with self.subTest(hss_variant=label):
231+
experiment = self._test_gen_base(
232+
hss=hss,
233+
expected_num_candidate_params=expected_num_candidate_params,
234+
)
235+
self._base_test_predict_and_cv(experiment=experiment)

ax/adapter/tests/test_torch_adapter.py

Lines changed: 50 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,90 +1251,35 @@ def test_pairwise_preference_generator(self) -> None:
12511251
X=X.expand(2, *X.shape), Y=comp_pair_Y.expand(2, *comp_pair_Y.shape)
12521252
)
12531253

1254-
def test_get_transformed_model_gen_args_with_target_point(self) -> None:
1255-
# Test that _get_transformed_model_gen_args correctly processes target_point
1256-
1257-
# Setup: create adapter with target arm in optimization config
1258-
experiment = get_branin_experiment(with_completed_trial=True)
1259-
pruning_target_parameterization = Arm(parameters={"x1": -5.0, "x2": 15.0})
1260-
optimization_config = none_throws(
1261-
experiment.optimization_config
1262-
).clone_with_args(
1263-
pruning_target_parameterization=pruning_target_parameterization
1264-
)
1265-
1266-
adapter = TorchAdapter(
1267-
generator=TorchGenerator(),
1268-
experiment=experiment,
1269-
transforms=Cont_X_trans,
1270-
)
1271-
1272-
# Execute: call _get_transformed_gen_args then _get_transformed_model_gen_args
1273-
base_gen_args = adapter._get_transformed_gen_args(
1274-
search_space=experiment.search_space,
1275-
optimization_config=optimization_config,
1276-
pending_observations={},
1277-
)
1278-
1279-
search_space_digest, torch_opt_config = adapter._get_transformed_model_gen_args(
1280-
search_space=base_gen_args.search_space,
1281-
pending_observations=base_gen_args.pending_observations,
1282-
fixed_features=base_gen_args.fixed_features,
1283-
optimization_config=base_gen_args.optimization_config,
1284-
)
1285-
1286-
# Assert: confirm pruning_target_point is correctly extracted and transformed
1287-
self.assertIsNotNone(torch_opt_config.pruning_target_point)
1288-
expected_target = torch.tensor([0.0, 1.0], dtype=torch.double)
1289-
torch.testing.assert_close(
1290-
torch_opt_config.pruning_target_point, expected_target
1254+
def _test_get_transformed_model_gen_args_target_point(
1255+
self,
1256+
with_status_quo: bool,
1257+
pruning_target_params: dict[str, float] | None,
1258+
expected_target: torch.Tensor | None,
1259+
) -> None:
1260+
experiment = get_branin_experiment(
1261+
with_completed_trial=True,
1262+
with_status_quo=with_status_quo,
12911263
)
12921264

1293-
def test_get_transformed_model_gen_args_no_target_point(self) -> None:
1294-
# Test that _get_transformed_model_gen_args handles
1295-
# pruning_target_parameterization=None correctly
1265+
opt_config = none_throws(experiment.optimization_config)
1266+
if pruning_target_params is not None:
1267+
pruning_target = Arm(parameters=pruning_target_params)
1268+
opt_config = opt_config.clone_with_args(
1269+
pruning_target_parameterization=pruning_target
1270+
)
1271+
elif with_status_quo:
1272+
opt_config = opt_config.clone()
12961273

1297-
# Setup: create adapter without target arm (default case)
1298-
experiment = get_branin_experiment(with_completed_trial=True)
12991274
adapter = TorchAdapter(
13001275
generator=TorchGenerator(),
13011276
experiment=experiment,
13021277
transforms=Cont_X_trans,
13031278
)
13041279

1305-
# Execute: call _get_transformed_gen_args then _get_transformed_model_gen_args
13061280
base_gen_args = adapter._get_transformed_gen_args(
13071281
search_space=experiment.search_space,
1308-
optimization_config=none_throws(experiment.optimization_config),
1309-
pending_observations={},
1310-
)
1311-
1312-
search_space_digest, torch_opt_config = adapter._get_transformed_model_gen_args(
1313-
search_space=base_gen_args.search_space,
1314-
pending_observations=base_gen_args.pending_observations,
1315-
fixed_features=base_gen_args.fixed_features,
1316-
optimization_config=base_gen_args.optimization_config,
1317-
)
1318-
1319-
# Assert: confirm target_point is None when no pruning_target_parameterization
1320-
# is provided
1321-
self.assertIsNone(torch_opt_config.pruning_target_point)
1322-
1323-
def test_get_transformed_model_gen_args_with_sq_as_target(self) -> None:
1324-
# Test that _get_transformed_model_gen_args correctly processes the status quo
1325-
# as the target point
1326-
experiment = get_branin_experiment(
1327-
with_completed_trial=True, with_status_quo=True
1328-
)
1329-
1330-
adapter = TorchAdapter(
1331-
generator=TorchGenerator(), experiment=experiment, transforms=Cont_X_trans
1332-
)
1333-
oc = none_throws(experiment.optimization_config).clone()
1334-
# Execute: call _get_transformed_gen_args then _get_transformed_model_gen_args
1335-
base_gen_args = adapter._get_transformed_gen_args(
1336-
search_space=experiment.search_space,
1337-
optimization_config=oc,
1282+
optimization_config=opt_config,
13381283
pending_observations={},
13391284
)
13401285

@@ -1345,12 +1290,38 @@ def test_get_transformed_model_gen_args_with_sq_as_target(self) -> None:
13451290
optimization_config=base_gen_args.optimization_config,
13461291
)
13471292

1348-
# Assert: confirm pruning_target_point is correctly extracted and transformed
1349-
self.assertIsNotNone(torch_opt_config.pruning_target_point)
1350-
expected_target = torch.tensor([1 / 3.0, 0.0], dtype=torch.double)
1351-
torch.testing.assert_close(
1352-
torch_opt_config.pruning_target_point, expected_target
1353-
)
1293+
if expected_target is None:
1294+
self.assertIsNone(torch_opt_config.pruning_target_point)
1295+
else:
1296+
self.assertIsNotNone(torch_opt_config.pruning_target_point)
1297+
torch.testing.assert_close(
1298+
torch_opt_config.pruning_target_point,
1299+
expected_target,
1300+
)
1301+
1302+
def test_get_transformed_model_gen_args_target_point(self) -> None:
1303+
# Test _get_transformed_model_gen_args with various target point scenarios
1304+
for label, with_status_quo, pruning_target_params, expected_target in [
1305+
(
1306+
"with_target_point",
1307+
False,
1308+
{"x1": -5.0, "x2": 15.0},
1309+
torch.tensor([0.0, 1.0], dtype=torch.double),
1310+
),
1311+
("no_target_point", False, None, None),
1312+
(
1313+
"sq_as_target",
1314+
True,
1315+
None,
1316+
torch.tensor([1 / 3.0, 0.0], dtype=torch.double),
1317+
),
1318+
]:
1319+
with self.subTest(scenario=label):
1320+
self._test_get_transformed_model_gen_args_target_point(
1321+
with_status_quo=with_status_quo,
1322+
pruning_target_params=pruning_target_params,
1323+
expected_target=expected_target,
1324+
)
13541325

13551326
@mock_botorch_optimize
13561327
def test_moo_with_derived_parameter(self) -> None:

ax/adapter/transforms/tests/test_logit_transform.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,11 @@ def test_InvalidSettings(self) -> None:
9595
self.assertEqual("x can't use both log and logit.", str(cm.exception))
9696

9797
str_exc = "x logit requires lower > 0 and upper < 1"
98-
with self.assertRaises(UserInputError) as cm:
99-
self._create_logit_parameter(lower=0.0, upper=0.5)
100-
self.assertEqual(str_exc, str(cm.exception))
101-
with self.assertRaises(UserInputError) as cm:
102-
self._create_logit_parameter(lower=0.3, upper=1.0)
103-
self.assertEqual(str_exc, str(cm.exception))
104-
with self.assertRaises(UserInputError) as cm:
105-
self._create_logit_parameter(lower=0.5, upper=10.0)
106-
self.assertEqual(str_exc, str(cm.exception))
98+
for lower, upper in [(0.0, 0.5), (0.3, 1.0), (0.5, 10.0)]:
99+
with self.subTest(lower=lower, upper=upper):
100+
with self.assertRaises(UserInputError) as cm:
101+
self._create_logit_parameter(lower=lower, upper=upper)
102+
self.assertEqual(str_exc, str(cm.exception))
107103

108104
def test_TransformSearchSpace(self) -> None:
109105
ss2 = deepcopy(self.search_space)

ax/adapter/transforms/tests/test_map_key_to_float_transform.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -249,11 +249,10 @@ def _test_early_stopping(self, complete_with_progression: bool) -> None:
249249
# Check that cross validation works.
250250
cross_validate(adapter=adapter)
251251

252-
def test_no_early_stopping_with_progression(self) -> None:
253-
self._test_no_early_stopping(with_progression=True)
254-
255-
def test_no_early_stopping_no_progression(self) -> None:
256-
self._test_no_early_stopping(with_progression=False)
252+
def test_no_early_stopping(self) -> None:
253+
for with_progression in (True, False):
254+
with self.subTest(with_progression=with_progression):
255+
self._test_no_early_stopping(with_progression=with_progression)
257256

258257
def test_early_stopping_with_final_progression(self) -> None:
259258
self._test_early_stopping(complete_with_progression=True)

ax/adapter/transforms/tests/test_objective_as_constraint.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -334,37 +334,25 @@ def test_relative_constraint_feasibility_check(self) -> None:
334334

335335
def test_leq_constraint_feasibility(self) -> None:
336336
"""Test feasibility checking with LEQ constraints."""
337-
# m2 <= 0.3 constraint. Both observations have m2 > 0.3, so infeasible.
338-
_, adapter, experiment_data = self._make_experiment_adapter_and_data(
339-
observations=[[1.0, 0.5], [2.0, 5.0]],
340-
constraint_bound=0.3,
341-
constraint_op=ComparisonOp.LEQ,
342-
)
343-
344-
t = ObjectiveAsConstraint(
345-
search_space=adapter._experiment.search_space,
346-
experiment_data=experiment_data,
347-
adapter=adapter,
348-
)
349-
350-
self.assertTrue(t._should_add_constraint)
351-
352-
def test_leq_constraint_feasible(self) -> None:
353-
"""Test that LEQ constraints with feasible points are correctly detected."""
354-
# m2 <= 10.0 constraint. Both observations have m2 <= 10.0, so feasible.
355-
_, adapter, experiment_data = self._make_experiment_adapter_and_data(
356-
observations=[[1.0, 0.5], [2.0, 5.0]],
357-
constraint_bound=10.0,
358-
constraint_op=ComparisonOp.LEQ,
359-
)
337+
cases = [
338+
(0.3, True, "infeasible"),
339+
(10.0, False, "feasible"),
340+
]
341+
for bound, expected_should_add, label in cases:
342+
with self.subTest(bound=bound, scenario=label):
343+
_, adapter, experiment_data = self._make_experiment_adapter_and_data(
344+
observations=[[1.0, 0.5], [2.0, 5.0]],
345+
constraint_bound=bound,
346+
constraint_op=ComparisonOp.LEQ,
347+
)
360348

361-
t = ObjectiveAsConstraint(
362-
search_space=adapter._experiment.search_space,
363-
experiment_data=experiment_data,
364-
adapter=adapter,
365-
)
349+
t = ObjectiveAsConstraint(
350+
search_space=adapter._experiment.search_space,
351+
experiment_data=experiment_data,
352+
adapter=adapter,
353+
)
366354

367-
self.assertFalse(t._should_add_constraint)
355+
self.assertEqual(t._should_add_constraint, expected_should_add)
368356

369357
def test_no_op_for_experiment_data(self) -> None:
370358
"""Test that transform_experiment_data is a no-op."""

0 commit comments

Comments
 (0)