Skip to content

Commit f588404

Browse files
Eric Onofreymeta-codesync[bot]
authored andcommitted
Consolidate ax/api and ax/benchmark tests
Summary: Part of a 19-diff stack to consolidate repetitive tests across Ax and PTS using `subTest`. Consolidate 11 test files in ax/api/ and ax/benchmark/ — adds subTest to benchmark result, benchmark metric, and API utils tests. Differential Revision: D95606935
1 parent 5ecd8a9 commit f588404

11 files changed

Lines changed: 182 additions & 164 deletions

ax/api/utils/instantiation/tests/test_from_config.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -349,21 +349,16 @@ def test_experiment_from_config(self) -> None:
349349
)
350350

351351
def test_parameter_type_converter(self) -> None:
352-
self.assertEqual(
353-
_parameter_type_converter(parameter_type="bool"),
354-
CoreParameterType.BOOL,
355-
)
356-
self.assertEqual(
357-
_parameter_type_converter(parameter_type="int"),
358-
CoreParameterType.INT,
359-
)
360-
self.assertEqual(
361-
_parameter_type_converter(parameter_type="float"),
362-
CoreParameterType.FLOAT,
363-
)
364-
self.assertEqual(
365-
_parameter_type_converter(parameter_type="str"),
366-
CoreParameterType.STRING,
367-
)
352+
for type_str, expected in [
353+
("bool", CoreParameterType.BOOL),
354+
("int", CoreParameterType.INT),
355+
("float", CoreParameterType.FLOAT),
356+
("str", CoreParameterType.STRING),
357+
]:
358+
with self.subTest(parameter_type=type_str):
359+
self.assertEqual(
360+
_parameter_type_converter(parameter_type=type_str),
361+
expected,
362+
)
368363
with self.assertRaisesRegex(UserInputError, "Unsupported parameter type"):
369364
_parameter_type_converter(parameter_type="bad")

ax/benchmark/tests/problems/surrogate/hss/test_cifar10_surrogate.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def test_cifar10_surrogate(self) -> None:
6868
benchmark = get_cifar10_surrogate_benchmark(num_trials=1)
6969

7070
for params, target_value in cases:
71-
self.assertAlmostEqual(
72-
benchmark.test_function.evaluate_true(params).item(),
73-
target_value,
74-
)
71+
with self.subTest(params=params):
72+
self.assertAlmostEqual(
73+
benchmark.test_function.evaluate_true(params).item(),
74+
target_value,
75+
)

ax/benchmark/tests/problems/surrogate/hss/test_fashion_mnist_surrogate.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def test_fashion_mnist_surrogate(self) -> None:
6868
benchmark = get_fashion_mnist_surrogate_benchmark(num_trials=1)
6969

7070
for params, target_value in cases:
71-
self.assertAlmostEqual(
72-
benchmark.test_function.evaluate_true(params).item(),
73-
target_value,
74-
)
71+
with self.subTest(params=params):
72+
self.assertAlmostEqual(
73+
benchmark.test_function.evaluate_true(params).item(),
74+
target_value,
75+
)

ax/benchmark/tests/problems/surrogate/hss/test_mnist_surrogate.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,11 @@ def test_mnist_surrogate(self) -> None:
5252
benchmark = get_mnist_surrogate_benchmark(num_trials=1)
5353

5454
for params, target_value in cases:
55-
self.assertAlmostEqual(
56-
benchmark.test_function.evaluate_true(params).item(),
57-
target_value,
58-
)
55+
with self.subTest(params=params):
56+
self.assertAlmostEqual(
57+
benchmark.test_function.evaluate_true(params).item(),
58+
target_value,
59+
)
5960

6061
def test_benchmark_creation(self) -> None:
6162
benchmark = get_mnist_surrogate_benchmark(num_trials=1)

ax/benchmark/tests/problems/synthetic/hss/test_jenatton.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,20 @@ def test_jenatton_test_function(self) -> None:
7878
)
7979

8080
for params, value in cases:
81-
self.assertAlmostEqual(
82-
# pyre-fixme: Incompatible parameter type [6]: In call
83-
# `jenatton_test_function`, for 1st positional argument,
84-
# expected `Optional[float]` but got `Union[None, bool, float,
85-
# int, str]`.
86-
jenatton_test_function(**params),
87-
value,
88-
)
89-
self.assertAlmostEqual(
90-
benchmark_problem.test_function.evaluate_true(params=params).item(),
91-
value,
92-
places=6,
93-
)
81+
with self.subTest(params=params, expected_value=value):
82+
self.assertAlmostEqual(
83+
# pyre-fixme: Incompatible parameter type [6]: In call
84+
# `jenatton_test_function`, for 1st positional argument,
85+
# expected `Optional[float]` but got `Union[None, bool, float,
86+
# int, str]`.
87+
jenatton_test_function(**params),
88+
value,
89+
)
90+
self.assertAlmostEqual(
91+
benchmark_problem.test_function.evaluate_true(params=params).item(),
92+
value,
93+
places=6,
94+
)
9495

9596
def test_create_problem(self) -> None:
9697
problem = get_jenatton_benchmark_problem()

ax/benchmark/tests/problems/synthetic/test_from_botorch.py

Lines changed: 58 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -74,59 +74,65 @@ def test_get_augmented_branin_problem(self) -> None:
7474
class TestFromBoTorch(TestCase):
7575
def test_single_objective_from_botorch(self) -> None:
7676
for botorch_test_problem in [Ackley(), ConstrainedHartmann(dim=6)]:
77-
test_problem = create_problem_from_botorch(
78-
test_problem_class=botorch_test_problem.__class__,
79-
test_problem_kwargs={},
80-
num_trials=1,
81-
baseline_value=100.0,
82-
)
83-
84-
# Test search space
85-
self.assertEqual(
86-
len(test_problem.search_space.parameters), botorch_test_problem.dim
87-
)
88-
self.assertEqual(
89-
len(test_problem.search_space.parameters),
90-
len(test_problem.search_space.range_parameters),
91-
)
92-
self.assertTrue(
93-
all(
94-
test_problem.search_space.range_parameters[f"x{i}"].lower
95-
== botorch_test_problem._bounds[i][0]
96-
for i in range(botorch_test_problem.dim)
97-
),
98-
"Parameters' lower bounds must all match Botorch problem's bounds.",
99-
)
100-
self.assertTrue(
101-
all(
102-
test_problem.search_space.range_parameters[f"x{i}"].upper
103-
== botorch_test_problem._bounds[i][1]
104-
for i in range(botorch_test_problem.dim)
105-
),
106-
"Parameters' upper bounds must all match Botorch problem's bounds.",
107-
)
77+
with self.subTest(problem=botorch_test_problem.__class__.__name__):
78+
test_problem = create_problem_from_botorch(
79+
test_problem_class=botorch_test_problem.__class__,
80+
test_problem_kwargs={},
81+
num_trials=1,
82+
baseline_value=100.0,
83+
)
10884

109-
# Test optimum
110-
self.assertEqual(
111-
test_problem.optimal_value, botorch_test_problem._optimal_value
112-
)
113-
# test optimization config
114-
metric_name = test_problem.optimization_config.objective.metric.name
115-
self.assertEqual(metric_name, test_problem.name)
116-
self.assertTrue(test_problem.optimization_config.objective.minimize)
117-
# test repr method
118-
if isinstance(botorch_test_problem, Ackley):
85+
# Test search space
86+
self.assertEqual(
87+
len(test_problem.search_space.parameters),
88+
botorch_test_problem.dim,
89+
)
11990
self.assertEqual(
120-
test_problem.optimization_config.outcome_constraints, []
91+
len(test_problem.search_space.parameters),
92+
len(test_problem.search_space.range_parameters),
12193
)
122-
else:
123-
outcome_constraint = (
124-
test_problem.optimization_config.outcome_constraints[0]
94+
self.assertTrue(
95+
all(
96+
test_problem.search_space.range_parameters[f"x{i}"].lower
97+
== botorch_test_problem._bounds[i][0]
98+
for i in range(botorch_test_problem.dim)
99+
),
100+
"Parameters' lower bounds must all match Botorch problem's bounds.",
101+
)
102+
self.assertTrue(
103+
all(
104+
test_problem.search_space.range_parameters[f"x{i}"].upper
105+
== botorch_test_problem._bounds[i][1]
106+
for i in range(botorch_test_problem.dim)
107+
),
108+
"Parameters' upper bounds must all match Botorch problem's bounds.",
109+
)
110+
111+
# Test optimum
112+
self.assertEqual(
113+
test_problem.optimal_value,
114+
botorch_test_problem._optimal_value,
125115
)
126-
self.assertEqual(outcome_constraint.metric.name, "constraint_slack_0")
127-
self.assertEqual(outcome_constraint.op, ComparisonOp.GEQ)
128-
self.assertFalse(outcome_constraint.relative)
129-
self.assertEqual(outcome_constraint.bound, 0.0)
116+
# test optimization config
117+
metric_name = test_problem.optimization_config.objective.metric.name
118+
self.assertEqual(metric_name, test_problem.name)
119+
self.assertTrue(test_problem.optimization_config.objective.minimize)
120+
# test repr method
121+
if isinstance(botorch_test_problem, Ackley):
122+
self.assertEqual(
123+
test_problem.optimization_config.outcome_constraints,
124+
[],
125+
)
126+
else:
127+
outcome_constraint = (
128+
test_problem.optimization_config.outcome_constraints[0]
129+
)
130+
self.assertEqual(
131+
outcome_constraint.metric.name, "constraint_slack_0"
132+
)
133+
self.assertEqual(outcome_constraint.op, ComparisonOp.GEQ)
134+
self.assertFalse(outcome_constraint.relative)
135+
self.assertEqual(outcome_constraint.bound, 0.0)
130136

131137
def _test_constrained_from_botorch(
132138
self,
@@ -259,8 +265,9 @@ def _test_moo_from_botorch(self, lower_is_better: bool) -> None:
259265
)
260266

261267
def test_moo_from_botorch(self) -> None:
262-
self._test_moo_from_botorch(lower_is_better=True)
263-
self._test_moo_from_botorch(lower_is_better=False)
268+
for lower_is_better in [True, False]:
269+
with self.subTest(lower_is_better=lower_is_better):
270+
self._test_moo_from_botorch(lower_is_better=lower_is_better)
264271

265272
def test_create_problem_from_botorch_with_shifted_function(self) -> None:
266273
ax_problem = create_problem_from_botorch(

ax/benchmark/tests/problems/test_mixed_integer_problems.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,31 @@ def test_problems(self) -> None:
2929
(Rosenbrock, get_discrete_rosenbrock, 10, 6),
3030
):
3131
name = problem_cls.__name__
32-
problem = constructor()
33-
self.assertEqual(f"Discrete {name}", problem.name)
34-
test_function = assert_is_instance(
35-
problem.test_function, BoTorchTestFunction
36-
)
37-
botorch_problem = test_function.botorch_problem
38-
self.assertIsInstance(botorch_problem, problem_cls)
39-
self.assertEqual(len(problem.search_space.parameters), dim)
40-
self.assertEqual(
41-
sum(
42-
p.parameter_type == ParameterType.INT
43-
for p in problem.search_space.parameters.values()
44-
),
45-
dim_int,
46-
)
47-
# Check that the underlying problem has the correct bounds.
48-
if name == "Rosenbrock":
49-
expected_bounds = [(-5.0, 10.0) for _ in range(dim)]
50-
else:
51-
expected_bounds = [(0.0, 1.0) for _ in range(dim)]
52-
self.assertEqual(botorch_problem._bounds, expected_bounds)
53-
self.assertGreaterEqual(problem.optimal_value, problem_cls().optimal_value)
32+
with self.subTest(problem=name):
33+
problem = constructor()
34+
self.assertEqual(f"Discrete {name}", problem.name)
35+
test_function = assert_is_instance(
36+
problem.test_function, BoTorchTestFunction
37+
)
38+
botorch_problem = test_function.botorch_problem
39+
self.assertIsInstance(botorch_problem, problem_cls)
40+
self.assertEqual(len(problem.search_space.parameters), dim)
41+
self.assertEqual(
42+
sum(
43+
p.parameter_type == ParameterType.INT
44+
for p in problem.search_space.parameters.values()
45+
),
46+
dim_int,
47+
)
48+
# Check that the underlying problem has the correct bounds.
49+
if name == "Rosenbrock":
50+
expected_bounds = [(-5.0, 10.0) for _ in range(dim)]
51+
else:
52+
expected_bounds = [(0.0, 1.0) for _ in range(dim)]
53+
self.assertEqual(botorch_problem._bounds, expected_bounds)
54+
self.assertGreaterEqual(
55+
problem.optimal_value, problem_cls().optimal_value
56+
)
5457

5558
# Test that they match correctly to the original problems.
5659
cases: list[tuple[BenchmarkProblem, dict[str, float], torch.Tensor]] = [
@@ -94,14 +97,15 @@ def test_problems(self) -> None:
9497
]
9598

9699
for problem, params, expected_arg in cases:
97-
test_function = assert_is_instance(
98-
problem.test_function, BoTorchTestFunction
99-
)
100-
with patch.object(
101-
test_function.botorch_problem,
102-
attribute="evaluate_true",
103-
wraps=test_function.botorch_problem.evaluate_true,
104-
) as mock_call:
105-
test_function.evaluate_true(params=params)
106-
actual = mock_call.call_args.kwargs["X"]
107-
self.assertAllClose(actual, expected_arg)
100+
with self.subTest(problem=problem.name, params=params):
101+
test_function = assert_is_instance(
102+
problem.test_function, BoTorchTestFunction
103+
)
104+
with patch.object(
105+
test_function.botorch_problem,
106+
attribute="evaluate_true",
107+
wraps=test_function.botorch_problem.evaluate_true,
108+
) as mock_call:
109+
test_function.evaluate_true(params=params)
110+
actual = mock_call.call_args.kwargs["X"]
111+
self.assertAllClose(actual, expected_arg)

ax/benchmark/tests/problems/test_problems.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ def test_name(self) -> None:
6464
for name in ["Discrete Ackley", "Discrete Hartmann", "Discrete Rosenbrock"]
6565
]
6666
for registry_key, problem_name in expected_names:
67-
problem = get_benchmark_problem(problem_key=registry_key)
68-
self.assertEqual(problem.name, problem_name)
67+
with self.subTest(registry_key=registry_key):
68+
problem = get_benchmark_problem(problem_key=registry_key)
69+
self.assertEqual(problem.name, problem_name)
6970

7071
def test_no_duplicates(self) -> None:
7172
problem_names = set()

ax/benchmark/tests/test_benchmark.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,9 @@ def _test_storage(self, map_data: bool) -> None:
189189
self.assertEqual(experiment, experiment)
190190

191191
def test_storage(self) -> None:
192-
self._test_storage(map_data=False)
193-
self._test_storage(map_data=True)
192+
for map_data in [False, True]:
193+
with self.subTest(map_data=map_data):
194+
self._test_storage(map_data=map_data)
194195

195196
def test_replication_sobol_synthetic(self) -> None:
196197
method = get_sobol_benchmark_method()

ax/benchmark/tests/test_benchmark_metric.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,9 @@ def _test_fetch_trial_multiple_time_steps_with_simulator(self, batch: bool) -> N
341341
self.assertEqual(df.reset_index(drop=True).to_dict(), expected_df.to_dict())
342342

343343
def test_fetch_trial_multiple_time_steps_with_simulator(self) -> None:
344-
self._test_fetch_trial_multiple_time_steps_with_simulator(batch=False)
345-
self._test_fetch_trial_multiple_time_steps_with_simulator(batch=True)
344+
for batch in [False, True]:
345+
with self.subTest(batch=batch):
346+
self._test_fetch_trial_multiple_time_steps_with_simulator(batch=batch)
346347

347348
def test_sim_trial_completes_in_future_raises(self) -> None:
348349
simulator = BackendSimulator()

0 commit comments

Comments
 (0)