Skip to content

Commit 779b245

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Migrate TL adapter utils tests to non-fb location and drop shims (#5217)
Summary: Pull Request resolved: #5217 TL;DR: The code was migrated to OSS, but the tests were left behind with some BC re-export shims that were unused outside of tests. This moves everything to OSS and deletes the unused shims. The TL adapter utilities `get_joint_search_space`, `merge_dependents`, `merge_parameters` (in `ax/adapter/transfer_learning/utils.py`) and `get_mapped_parameter_names` (in `ax/adapter/transfer_learning/utils_torch.py`) were previously migrated out of `ax/fb/adapter/`, which left behind pure re-export shims at `ax/fb/adapter/utils.py` and `ax/fb/adapter/utils_torch.py`. The only remaining coverage for these functions lived in `ax/fb/adapter/tests/test_utils.py` and `test_utils_torch.py`, exercising the migrated code through those shims -- and the non-fb destination had no test coverage of its own. This moves both test files to `ax/adapter/transfer_learning/tests/`, switches their imports to the real non-fb modules (`ax.adapter.transfer_learning.utils`/`utils_torch`, `ax.core.auxiliary_source.AuxiliarySource`, `ax.utils.common.testutils.TestCase`), and removes the now-unused `get_unordered_choice`/`get_ordered_choice` helpers from `test_utils.py`. Since the two test files were the only callers of the `ax.fb.adapter.utils`/`utils_torch` shims repo-wide, the shims are deleted and the BUCK targets are updated accordingly: a new `test_utils` `python_unittest` is added under `ax/adapter/transfer_learning/BUCK`, and the old `test_utils` target, the orphaned `:utils` library, and the `utils_torch.py` src are removed from `ax/fb/adapter/BUCK`. The broader `ax.fb.core.auxiliary_source` shim is left in place; it still has many callers across admarket, pts, automl, storage, and docs, so cleaning it up is a separate effort. Reviewed By: hvarfner Differential Revision: D107429272 fbshipit-source-id: a1c1cdb2cd80632a751bbdf44a3c62f68e916630
1 parent f8106ac commit 779b245

2 files changed

Lines changed: 467 additions & 0 deletions

File tree

Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
from ax.adapter.transfer_learning.utils import (
9+
get_joint_search_space,
10+
merge_dependents,
11+
merge_parameters,
12+
)
13+
from ax.core.auxiliary_source import AuxiliarySource
14+
from ax.core.experiment import Experiment
15+
from ax.core.parameter import (
16+
ChoiceParameter,
17+
DerivedParameter,
18+
FixedParameter,
19+
Parameter,
20+
ParameterType,
21+
RangeParameter,
22+
)
23+
from ax.core.search_space import SearchSpace
24+
from ax.utils.common.testutils import TestCase
25+
from pyre_extensions import assert_is_instance, none_throws
26+
27+
28+
class AxFbCoreUtilsTest(TestCase):
29+
def test_get_joint_search_space(self) -> None:
30+
parameters: list[Parameter] = [
31+
RangeParameter(f"x{i}", parameter_type=ParameterType.INT, lower=0, upper=5)
32+
for i in range(3)
33+
]
34+
exp1 = Experiment(
35+
search_space=SearchSpace(parameters=parameters[:2]), name="test1"
36+
)
37+
exp2 = Experiment(
38+
search_space=SearchSpace(parameters=parameters[:2]), name="test2"
39+
)
40+
exp3 = Experiment(
41+
search_space=SearchSpace(parameters=parameters[1:]), name="test3"
42+
)
43+
aux_2 = AuxiliarySource(experiment=exp2)
44+
aux_3 = AuxiliarySource(experiment=exp3)
45+
aux_4 = AuxiliarySource(experiment=exp3, transfer_param_config={"x0": "x2"})
46+
for exp, aux_srcs, expected_params in (
47+
(exp1, [aux_2], {"x0", "x1"}),
48+
(exp1, [aux_2, aux_3], {"x0", "x1", "x2"}),
49+
(exp1, [aux_2, aux_4], {"x0", "x1"}),
50+
):
51+
self.assertEqual(
52+
set(
53+
get_joint_search_space(
54+
search_space=exp.search_space, auxiliary_sources=aux_srcs
55+
).parameters.keys()
56+
),
57+
expected_params,
58+
)
59+
60+
def test_get_joint_search_space_update_fixed_params(self) -> None:
61+
# test update fixed params
62+
range_param = RangeParameter(
63+
"x", parameter_type=ParameterType.INT, lower=0, upper=5
64+
)
65+
fixed_param1 = FixedParameter("y", parameter_type=ParameterType.INT, value=1)
66+
fixed_param2 = FixedParameter("y", parameter_type=ParameterType.INT, value=2)
67+
exp = Experiment(
68+
search_space=SearchSpace(parameters=[range_param, fixed_param1]),
69+
name="test1",
70+
)
71+
exp2 = Experiment(
72+
search_space=SearchSpace(parameters=[range_param, fixed_param2]),
73+
name="test2",
74+
)
75+
for update_fixed_params in [True, False]:
76+
aux2 = AuxiliarySource(
77+
experiment=exp2, update_fixed_params=update_fixed_params
78+
)
79+
ss_params = get_joint_search_space(
80+
search_space=exp.search_space, auxiliary_sources=[aux2]
81+
).parameters
82+
self.assertEqual(
83+
assert_is_instance(ss_params["y"], FixedParameter).value, 1
84+
)
85+
self.assertIn("x", ss_params)
86+
87+
def test_get_joint_search_space_with_hss_and_choice(self) -> None:
88+
ss1 = SearchSpace(
89+
parameters=[
90+
FixedParameter(
91+
"root",
92+
parameter_type=ParameterType.INT,
93+
value=1,
94+
dependents={1: ["learning_rate", "optimizer", "method"]},
95+
),
96+
ChoiceParameter(
97+
"learning_rate",
98+
parameter_type=ParameterType.FLOAT,
99+
values=[0.01, 0.05],
100+
),
101+
ChoiceParameter(
102+
"optimizer",
103+
parameter_type=ParameterType.STRING,
104+
values=["Adam", "SGD", "AdaGrad"],
105+
),
106+
ChoiceParameter(
107+
"method",
108+
parameter_type=ParameterType.STRING,
109+
values=["train", "eval"],
110+
),
111+
]
112+
)
113+
ss2 = SearchSpace(
114+
parameters=[
115+
FixedParameter(
116+
"root2",
117+
parameter_type=ParameterType.INT,
118+
value=1,
119+
dependents={1: ["lr", "optimizer"]},
120+
),
121+
ChoiceParameter(
122+
"lr", parameter_type=ParameterType.FLOAT, values=[0.01, 0.1]
123+
),
124+
ChoiceParameter(
125+
"optimizer",
126+
parameter_type=ParameterType.STRING,
127+
values=["Adam", "SGD"],
128+
),
129+
]
130+
)
131+
aux_src = AuxiliarySource(
132+
experiment=Experiment(search_space=ss2, name="test"),
133+
transfer_param_config={"learning_rate": "lr", "root": "root2"},
134+
update_fixed_params=False,
135+
)
136+
joint_ss = get_joint_search_space(search_space=ss1, auxiliary_sources=[aux_src])
137+
self.assertEqual(
138+
set(joint_ss.parameters.keys()),
139+
{"root", "learning_rate", "optimizer", "method"},
140+
)
141+
self.assertEqual(
142+
set(joint_ss["root"].dependents[1]),
143+
{"learning_rate", "optimizer", "method"},
144+
)
145+
self.assertEqual(
146+
assert_is_instance(
147+
joint_ss.parameters["learning_rate"], ChoiceParameter
148+
).values,
149+
[0.01, 0.05, 0.1],
150+
)
151+
self.assertEqual(
152+
set(
153+
assert_is_instance(
154+
joint_ss.parameters["optimizer"], ChoiceParameter
155+
).values
156+
),
157+
{"Adam", "SGD", "AdaGrad"},
158+
)
159+
160+
def test_merge_dependents(self) -> None:
161+
p_no_dependents = FixedParameter(
162+
"p", parameter_type=ParameterType.BOOL, value=True
163+
)
164+
# No dependents returns None.
165+
self.assertIsNone(
166+
merge_dependents(
167+
p1=p_no_dependents, p2=p_no_dependents, reverse_param_config={}
168+
)
169+
)
170+
p_dependents_1 = FixedParameter(
171+
"p1", parameter_type=ParameterType.INT, value=1, dependents={1: ["q"]}
172+
)
173+
p_dependents_2 = FixedParameter(
174+
"p2", parameter_type=ParameterType.INT, value=1, dependents={1: ["z"]}
175+
)
176+
# p1 dependents do not get renamed.
177+
self.assertEqual(
178+
merge_dependents(
179+
p1=p_dependents_1, p2=p_no_dependents, reverse_param_config={"q": "w"}
180+
),
181+
{1: ["q"]},
182+
)
183+
# p2 dependents get renamed.
184+
self.assertEqual(
185+
merge_dependents(
186+
p1=p_no_dependents, p2=p_dependents_1, reverse_param_config={"q": "w"}
187+
),
188+
{1: ["w"]},
189+
)
190+
# Merge p1 & p2 dependents with renaming for p2 only.
191+
self.assertEqual(
192+
set(
193+
none_throws(
194+
merge_dependents(
195+
p1=p_dependents_1,
196+
p2=p_dependents_2,
197+
reverse_param_config={"q": "w", "z": "v"},
198+
)
199+
)[1]
200+
),
201+
{"q", "v"},
202+
)
203+
204+
def test_merge_parameters(self) -> None:
205+
p_fixed = FixedParameter(
206+
name="fixed", parameter_type=ParameterType.BOOL, value=True
207+
)
208+
p_fixed_2 = FixedParameter(name="f2", parameter_type=ParameterType.INT, value=1)
209+
p_fixed_3 = FixedParameter(name="f3", parameter_type=ParameterType.INT, value=2)
210+
p_fixed_4 = FixedParameter(
211+
name="f4", parameter_type=ParameterType.INT, value=1, dependents={1: ["a"]}
212+
)
213+
with self.assertRaisesRegex(ValueError, "different names"):
214+
merge_parameters(p1=p_fixed, p2=p_fixed_2, reverse_param_config={})
215+
with self.assertRaisesRegex(ValueError, "different types"):
216+
merge_parameters(
217+
p1=p_fixed, p2=p_fixed_2, reverse_param_config={"f2": "fixed"}
218+
)
219+
# Check that it works with both values of update_fixed_params.
220+
for update_fixed_params in [True, False]:
221+
self.assertEqual(
222+
merge_parameters(
223+
p1=p_fixed_2,
224+
p2=p_fixed_3,
225+
reverse_param_config={"f3": "f2"},
226+
update_fixed_params=update_fixed_params,
227+
),
228+
FixedParameter(
229+
name="f2",
230+
parameter_type=ParameterType.INT,
231+
value=1,
232+
),
233+
)
234+
self.assertEqual(
235+
merge_parameters(
236+
p1=p_fixed_2, p2=p_fixed_4, reverse_param_config={"f4": "f2"}
237+
),
238+
FixedParameter(
239+
name="f2",
240+
parameter_type=ParameterType.INT,
241+
value=1,
242+
dependents={1: ["a"]},
243+
),
244+
)
245+
p_range_1 = RangeParameter(
246+
name="p", parameter_type=ParameterType.INT, lower=1, upper=3
247+
)
248+
p_range_2 = RangeParameter(
249+
name="p", parameter_type=ParameterType.INT, lower=0, upper=2
250+
)
251+
self.assertEqual(
252+
merge_parameters(p1=p_range_1, p2=p_range_2, reverse_param_config={}),
253+
RangeParameter(
254+
name="p", parameter_type=ParameterType.INT, lower=0, upper=3
255+
),
256+
)
257+
p_choice_1 = ChoiceParameter(
258+
name="p",
259+
parameter_type=ParameterType.STRING,
260+
values=["a", "b", "c"],
261+
dependents={"a": ["p1"], "c": ["p2"]},
262+
)
263+
p_choice_2 = ChoiceParameter(
264+
name="p", parameter_type=ParameterType.STRING, values=["a", "b", "d"]
265+
)
266+
self.assertEqual(
267+
merge_parameters(p1=p_choice_1, p2=p_choice_2, reverse_param_config={}),
268+
ChoiceParameter(
269+
name="p",
270+
parameter_type=ParameterType.STRING,
271+
values=["a", "b", "c", "d"],
272+
dependents={"a": ["p1"], "c": ["p2"]},
273+
),
274+
)
275+
276+
# FixedParameter + ChoiceParameter: fixed value already in choices.
277+
p_fixed_str = FixedParameter(
278+
name="p", parameter_type=ParameterType.STRING, value="a"
279+
)
280+
merged_fc = merge_parameters(
281+
p1=p_fixed_str, p2=p_choice_1, reverse_param_config={}
282+
)
283+
self.assertIsInstance(merged_fc, ChoiceParameter)
284+
merged_fc_choice = assert_is_instance(merged_fc, ChoiceParameter)
285+
self.assertEqual(set(merged_fc_choice.values), {"a", "b", "c"})
286+
# Dependents from the choice parameter are preserved.
287+
self.assertEqual(merged_fc_choice.dependents, {"a": ["p1"], "c": ["p2"]})
288+
289+
# FixedParameter + ChoiceParameter: fixed value NOT in choices.
290+
p_fixed_str_new = FixedParameter(
291+
name="p", parameter_type=ParameterType.STRING, value="z"
292+
)
293+
merged_fc2 = merge_parameters(
294+
p1=p_fixed_str_new, p2=p_choice_1, reverse_param_config={}
295+
)
296+
self.assertEqual(
297+
set(assert_is_instance(merged_fc2, ChoiceParameter).values),
298+
{"a", "b", "c", "z"},
299+
)
300+
301+
# Reversed order: ChoiceParameter as p1, FixedParameter as p2.
302+
merged_cf = merge_parameters(
303+
p1=p_choice_1, p2=p_fixed_str_new, reverse_param_config={}
304+
)
305+
self.assertEqual(
306+
set(assert_is_instance(merged_cf, ChoiceParameter).values),
307+
{"a", "b", "c", "z"},
308+
)
309+
310+
# DerivedParameter: same expression succeeds.
311+
p_derived_1 = DerivedParameter(
312+
name="d",
313+
parameter_type=ParameterType.FLOAT,
314+
expression_str="0.5 * x + 0.3 * y",
315+
)
316+
p_derived_2 = DerivedParameter(
317+
name="d",
318+
parameter_type=ParameterType.FLOAT,
319+
expression_str="0.5 * x + 0.3 * y",
320+
)
321+
merged = merge_parameters(
322+
p1=p_derived_1, p2=p_derived_2, reverse_param_config={}
323+
)
324+
self.assertIsInstance(merged, DerivedParameter)
325+
self.assertEqual(
326+
assert_is_instance(merged, DerivedParameter).expression_str,
327+
"0.5 * x + 0.3 * y",
328+
)
329+
self.assertEqual(merged.name, "d")
330+
331+
# DerivedParameter: different expressions raises ValueError.
332+
p_derived_3 = DerivedParameter(
333+
name="d",
334+
parameter_type=ParameterType.FLOAT,
335+
expression_str="0.7 * x + 0.1 * y",
336+
)
337+
with self.assertRaisesRegex(ValueError, "different expressions"):
338+
merge_parameters(p1=p_derived_1, p2=p_derived_3, reverse_param_config={})
339+
340+
# DerivedParameter vs FixedParameter raises ValueError (type mismatch).
341+
p_fixed_float = FixedParameter(
342+
name="d", parameter_type=ParameterType.FLOAT, value=1.0
343+
)
344+
with self.assertRaisesRegex(ValueError, "different types"):
345+
merge_parameters(
346+
p1=p_derived_1,
347+
p2=p_fixed_float,
348+
reverse_param_config={},
349+
)

0 commit comments

Comments
 (0)