Skip to content

Commit 60600e0

Browse files
Carl Hvarfnermeta-codesync[bot]
authored andcommitted
Support merging FixedParameter with ChoiceParameter in transfer learning (#5102)
Summary: Pull Request resolved: #5102 When merging search spaces for transfer learning, a parameter may be FixedParameter in one experiment and ChoiceParameter in another (e.g., a parameter was fixed to a single value in the source but is tunable in the target). Previously this raised a ValueError. Now we merge them into a ChoiceParameter whose values include the union of the choice values and the fixed value. Reviewed By: saitcakmak Differential Revision: D98247197 fbshipit-source-id: cb4adf32e886ec26dc24d8f0a94e5e64ee20c58f
1 parent 17ac795 commit 60600e0

1 file changed

Lines changed: 35 additions & 2 deletions

File tree

ax/adapter/transfer_learning/utils.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def merge_parameters(
8181
If both are choice parameters, they will be merged into a choice parameter that
8282
includes the union of the values of the two parameters.
8383
84+
If one is a fixed parameter and the other a choice parameter, they will be
85+
merged into a choice parameter whose values include the fixed value.
86+
8487
If the parameters have dependents (for hierarchical search spaces), then the
8588
dependents will be merged together.
8689
"""
@@ -91,9 +94,12 @@ def merge_parameters(
9194
)
9295
p1_type = type(p1)
9396
p2_type = type(p2)
97+
allowed_mixed_pairs = (
98+
{FixedParameter, RangeParameter},
99+
{FixedParameter, ChoiceParameter},
100+
)
94101
if (
95-
p1_type is not p2_type
96-
and ({p1_type, p2_type} != {FixedParameter, RangeParameter})
102+
p1_type is not p2_type and ({p1_type, p2_type} not in allowed_mixed_pairs)
97103
) or p1.parameter_type != p2.parameter_type:
98104
raise ValueError(f"Cannot merge parameters of different types: {p1}, {p2}.")
99105
if isinstance(p1, RangeParameter) and isinstance(p2, RangeParameter):
@@ -140,6 +146,33 @@ def merge_parameters(
140146
upper=max(range_param.upper, range_param.cast(fixed_param.value)),
141147
)
142148
return new_range_param
149+
elif (
150+
isinstance(fixed_param := p1, FixedParameter)
151+
and isinstance(choice_param := p2, ChoiceParameter)
152+
) or (
153+
isinstance(fixed_param := p2, FixedParameter)
154+
and isinstance(choice_param := p1, ChoiceParameter)
155+
):
156+
# Merge FixedParameter into ChoiceParameter by including the fixed
157+
# value in the set of choice values.
158+
values = list(set(choice_param.values) | {fixed_param.value})
159+
return ChoiceParameter(
160+
name=p1.name,
161+
parameter_type=p1.parameter_type,
162+
values=values,
163+
is_ordered=choice_param.is_ordered,
164+
is_task=choice_param.is_task,
165+
is_fidelity=choice_param.is_fidelity,
166+
target_value=choice_param.target_value,
167+
sort_values=choice_param.sort_values,
168+
dependents=merge_dependents(
169+
# pyre-ignore[6]: p1/p2 are FixedParameter | ChoiceParameter here.
170+
p1=p1,
171+
# pyre-ignore[6]: p1/p2 are FixedParameter | ChoiceParameter here.
172+
p2=p2,
173+
reverse_param_config=reverse_param_config,
174+
),
175+
)
143176
elif isinstance(p1, ChoiceParameter) and isinstance(p2, ChoiceParameter):
144177
return ChoiceParameter(
145178
name=p1.name,

0 commit comments

Comments
 (0)