Skip to content

Commit 70ad75c

Browse files
eonofreymeta-codesync[bot]
authored andcommitted
Consolidate ax/generator tests
Summary: Part of a 19-diff stack to consolidate repetitive tests across Ax using `subTest`. Consolidate 8 test files in ax/generators/ — adds subTest to Thompson sampler weight configs, kernel tests, covariance module argparse, and generator utils tests. Differential Revision: D95604917
1 parent abedf82 commit 70ad75c

8 files changed

Lines changed: 314 additions & 284 deletions

File tree

ax/generators/tests/test_discrete.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@ def test_discrete_model_get_state(self) -> None:
1616
discrete_model = DiscreteGenerator()
1717
self.assertEqual(discrete_model._get_state(), {})
1818

19-
def test_discrete_model_feature_importances(self) -> None:
20-
discrete_model = DiscreteGenerator()
21-
with self.assertRaises(NotImplementedError):
22-
discrete_model.feature_importances()
23-
2419
def test_DiscreteGeneratorFit(self) -> None:
2520
discrete_model = DiscreteGenerator()
2621
discrete_model.fit(
@@ -38,16 +33,18 @@ def test_discreteModelPredict(self) -> None:
3833
with self.assertRaises(NotImplementedError):
3934
discrete_model.predict([[0]])
4035

41-
def test_discreteModelGen(self) -> None:
36+
def test_not_implemented_methods(self) -> None:
4237
discrete_model = DiscreteGenerator()
43-
with self.assertRaises(NotImplementedError):
44-
discrete_model.gen(
38+
cases = {
39+
"feature_importances": lambda: discrete_model.feature_importances(),
40+
"gen": lambda: discrete_model.gen(
4541
n=1, parameter_values=[[0, 1]], objective_weights=np.array([1])
46-
)
47-
48-
def test_discreteModelCrossValidate(self) -> None:
49-
discrete_model = DiscreteGenerator()
50-
with self.assertRaises(NotImplementedError):
51-
discrete_model.cross_validate(
42+
),
43+
"cross_validate": lambda: discrete_model.cross_validate(
5244
Xs_train=[[[0]]], Ys_train=[[1]], Yvars_train=[[1]], X_test=[[1]]
53-
)
45+
),
46+
}
47+
for method_name, call in cases.items():
48+
with self.subTest(method=method_name):
49+
with self.assertRaises(NotImplementedError):
50+
call()

ax/generators/tests/test_random.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,24 @@ def test_seed(self) -> None:
3131

3232
def test_state(self) -> None:
3333
for model in (self.random_model, RandomGenerator(seed=5)):
34-
state = model._get_state()
35-
self.assertEqual(state["seed"], model.seed)
36-
self.assertEqual(state["init_position"], model.init_position)
34+
with self.subTest(seed=model.seed):
35+
state = model._get_state()
36+
self.assertEqual(state["seed"], model.seed)
37+
self.assertEqual(state["init_position"], model.init_position)
3738

38-
def test_RandomGeneratorGenSamples(self) -> None:
39-
with self.assertRaises(NotImplementedError):
40-
self.random_model._gen_samples(
39+
def test_not_implemented_methods(self) -> None:
40+
cases = {
41+
"_gen_samples": lambda: self.random_model._gen_samples(
4142
n=1, tunable_d=1, bounds=np.array([[0.0, 1.0]])
42-
)
43-
44-
def test_RandomGeneratorGenUnconstrained(self) -> None:
45-
with self.assertRaises(NotImplementedError):
46-
self.random_model._gen_unconstrained(
43+
),
44+
"_gen_unconstrained": lambda: self.random_model._gen_unconstrained(
4745
n=1, d=2, tunable_feature_indices=np.array([], dtype=int)
48-
)
46+
),
47+
}
48+
for method_name, call in cases.items():
49+
with self.subTest(method=method_name):
50+
with self.assertRaises(NotImplementedError):
51+
call()
4952

5053
def test_ConvertEqualityConstraints(self) -> None:
5154
fixed_features = {3: 0.7, 1: 0.5}

ax/generators/tests/test_thompson.py

Lines changed: 72 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -57,41 +57,80 @@ def test_ThompsonSampler(self) -> None:
5757
self.assertEqual(len(gen_metadata["arms_to_weights"]), 4)
5858
self.assertEqual(gen_metadata["best_x"], arms[0])
5959

60+
def test_ThompsonSamplerWeightConfigs(self) -> None:
61+
for label, min_weight, uniform_weights, expected_arms, expected_weights in [
62+
(
63+
"min_weight=0.01",
64+
0.01,
65+
False,
66+
[[4, 4], [3, 3], [2, 2]],
67+
[3 * i for i in [0.725, 0.225, 0.05]],
68+
),
69+
(
70+
"uniform_weights",
71+
0.0,
72+
True,
73+
[[4, 4], [3, 3], [2, 2]],
74+
[1.0, 1.0, 1.0],
75+
),
76+
]:
77+
with self.subTest(config=label):
78+
np.random.seed(0)
79+
generator = ThompsonSampler(
80+
min_weight=min_weight, uniform_weights=uniform_weights
81+
)
82+
generator.fit(
83+
Xs=self.Xs,
84+
Ys=self.Ys,
85+
Yvars=self.Yvars,
86+
parameter_values=self.parameter_values,
87+
outcome_names=self.outcome_names,
88+
)
89+
arms, weights, _ = generator.gen(
90+
n=3,
91+
parameter_values=self.parameter_values,
92+
objective_weights=np.ones(1),
93+
)
94+
self.assertEqual(arms, expected_arms)
95+
for weight, expected_weight in zip(weights, expected_weights):
96+
self.assertAlmostEqual(weight, expected_weight, 1)
97+
6098
def test_ThompsonSamplerValidation(self) -> None:
6199
generator = ThompsonSampler(min_weight=0.01)
62100

63-
# all Xs are not the same
64-
with self.assertRaises(ValueError):
65-
generator.fit(
66-
Xs=[[[1, 1], [2, 2], [3, 3], [4, 4]], [[1, 1], [2, 2], [4, 4]]],
67-
Ys=self.Ys,
68-
Yvars=self.Yvars,
69-
parameter_values=self.parameter_values,
70-
outcome_names=self.outcome_names,
71-
)
101+
with self.subTest(case="mismatched_Xs"):
102+
with self.assertRaises(ValueError):
103+
generator.fit(
104+
Xs=[[[1, 1], [2, 2], [3, 3], [4, 4]], [[1, 1], [2, 2], [4, 4]]],
105+
Ys=self.Ys,
106+
Yvars=self.Yvars,
107+
parameter_values=self.parameter_values,
108+
outcome_names=self.outcome_names,
109+
)
72110

73-
# multiple observations per parameterization
74-
with self.assertRaises(ValueError):
111+
with self.subTest(case="duplicate_parameterizations"):
112+
with self.assertRaises(ValueError):
113+
generator.fit(
114+
Xs=[[[1, 1], [2, 2], [2, 2]]],
115+
Ys=self.Ys,
116+
Yvars=self.Yvars,
117+
parameter_values=self.parameter_values,
118+
outcome_names=self.outcome_names,
119+
)
120+
121+
with self.subTest(case="similar_but_different_observations"):
122+
# these are not the same observations, so should not error
75123
generator.fit(
76-
Xs=[[[1, 1], [2, 2], [2, 2]]],
124+
Xs=[[[1, 1], [2.0, 2], [2, 2]]],
77125
Ys=self.Ys,
78126
Yvars=self.Yvars,
79127
parameter_values=self.parameter_values,
80128
outcome_names=self.outcome_names,
81129
)
82130

83-
# these are not the same observations, so should not error
84-
generator.fit(
85-
Xs=[[[1, 1], [2.0, 2], [2, 2]]],
86-
Ys=self.Ys,
87-
Yvars=self.Yvars,
88-
parameter_values=self.parameter_values,
89-
outcome_names=self.outcome_names,
90-
)
91-
92-
# requires objective weights
93-
with self.assertRaises(ValueError):
94-
generator.gen(5, self.parameter_values, objective_weights=None)
131+
with self.subTest(case="missing_objective_weights"):
132+
with self.assertRaises(ValueError):
133+
generator.gen(5, self.parameter_values, objective_weights=None)
95134

96135
def test_ThompsonSamplerTopKError(self) -> None:
97136
generator = ThompsonSampler(topk=5)
@@ -156,45 +195,6 @@ def test_TopTwo_alters_weights_vs_TopOne(self) -> None:
156195
# 4) Monotonicity in the final TTTS distribution still holds
157196
self.assertTrue(full_w2[3] > full_w2[2] > full_w2[1] > full_w2[0])
158197

159-
def test_ThompsonSamplerMinWeight(self) -> None:
160-
np.random.seed(0)
161-
generator = ThompsonSampler(min_weight=0.01)
162-
generator.fit(
163-
Xs=self.Xs,
164-
Ys=self.Ys,
165-
Yvars=self.Yvars,
166-
parameter_values=self.parameter_values,
167-
outcome_names=self.outcome_names,
168-
)
169-
arms, weights, _ = generator.gen(
170-
n=3,
171-
parameter_values=self.parameter_values,
172-
objective_weights=np.ones(1),
173-
)
174-
self.assertEqual(arms, [[4, 4], [3, 3], [2, 2]])
175-
for weight, expected_weight in zip(
176-
weights, [3 * i for i in [0.725, 0.225, 0.05]]
177-
):
178-
self.assertAlmostEqual(weight, expected_weight, 1)
179-
180-
def test_ThompsonSamplerUniformWeights(self) -> None:
181-
generator = ThompsonSampler(min_weight=0.0, uniform_weights=True)
182-
generator.fit(
183-
Xs=self.Xs,
184-
Ys=self.Ys,
185-
Yvars=self.Yvars,
186-
parameter_values=self.parameter_values,
187-
outcome_names=self.outcome_names,
188-
)
189-
arms, weights, _ = generator.gen(
190-
n=3,
191-
parameter_values=self.parameter_values,
192-
objective_weights=np.ones(1),
193-
)
194-
self.assertEqual(arms, [[4, 4], [3, 3], [2, 2]])
195-
for weight, expected_weight in zip(weights, [1.0, 1.0, 1.0]):
196-
self.assertAlmostEqual(weight, expected_weight, 1)
197-
198198
def test_ThompsonSamplerInfeasible(self) -> None:
199199
generator = ThompsonSampler(min_weight=0.9)
200200
generator.fit(
@@ -302,9 +302,12 @@ def test_ThompsonSamplerNonPositiveN(self) -> None:
302302
outcome_names=self.outcome_names,
303303
)
304304
for n in (-1, 0):
305-
with self.assertRaisesRegex(ValueError, "ThompsonSampler requires n > 0"):
306-
generator.gen(
307-
n=n,
308-
parameter_values=self.parameter_values,
309-
objective_weights=np.ones(1),
310-
)
305+
with self.subTest(n=n):
306+
with self.assertRaisesRegex(
307+
ValueError, "ThompsonSampler requires n > 0"
308+
):
309+
generator.gen(
310+
n=n,
311+
parameter_values=self.parameter_values,
312+
objective_weights=np.ones(1),
313+
)

ax/generators/tests/test_torch.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,25 @@ def test_TorchModelFit(self) -> None:
3939
),
4040
)
4141

42-
def test_TorchModelPredict(self) -> None:
42+
def test_not_implemented_methods(self) -> None:
4343
torch_model = TorchGenerator()
44-
with self.assertRaises(NotImplementedError):
45-
torch_model.predict(torch.zeros(1))
46-
47-
def test_TorchModelGen(self) -> None:
48-
torch_model = TorchGenerator()
49-
with self.assertRaises(NotImplementedError):
50-
torch_model.gen(
44+
cases = {
45+
"predict": lambda: torch_model.predict(torch.zeros(1)),
46+
"gen": lambda: torch_model.gen(
5147
n=1,
5248
search_space_digest=self.search_space_digest,
5349
torch_opt_config=self.torch_opt_config,
54-
)
50+
),
51+
"cross_validate": lambda: torch_model.cross_validate(
52+
datasets=[self.dataset],
53+
X_test=torch.ones(1),
54+
search_space_digest=SearchSpaceDigest(feature_names=[], bounds=[]),
55+
),
56+
}
57+
for method_name, call in cases.items():
58+
with self.subTest(method=method_name):
59+
with self.assertRaises(NotImplementedError):
60+
call()
5561

5662
def test_NumpyTorchBestPoint(self) -> None:
5763
torch_model = TorchGenerator()
@@ -60,12 +66,3 @@ def test_NumpyTorchBestPoint(self) -> None:
6066
torch_opt_config=self.torch_opt_config,
6167
)
6268
self.assertIsNone(x)
63-
64-
def test_TorchModelCrossValidate(self) -> None:
65-
torch_model = TorchGenerator()
66-
with self.assertRaises(NotImplementedError):
67-
torch_model.cross_validate(
68-
datasets=[self.dataset],
69-
X_test=torch.ones(1),
70-
search_space_digest=SearchSpaceDigest(feature_names=[], bounds=[]),
71-
)

0 commit comments

Comments
 (0)