Skip to content

Commit 5eed9cc

Browse files
David Erikssonfacebook-github-bot
David Eriksson
authored andcommitted
Update ChoiceParameter encoder (#3510)
Summary: `sort_values` wasn't properly extracted in the encoder. This leads to a bunch of log spew when decoding a choice parameter. Differential Revision: D71261719
1 parent af67610 commit 5eed9cc

File tree

3 files changed

+13
-0
lines changed

3 files changed

+13
-0
lines changed

ax/storage/json_store/encoders.py

+1
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def choice_parameter_to_dict(parameter: ChoiceParameter) -> dict[str, Any]:
189189
"is_fidelity": parameter.is_fidelity,
190190
"target_value": parameter.target_value,
191191
"dependents": parameter.dependents if parameter.is_hierarchical else None,
192+
"sort_values": parameter.sort_values,
192193
}
193194

194195

ax/storage/json_store/tests/test_json_store.py

+2
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
get_scheduler_options_batch_trial,
118118
get_search_space,
119119
get_sebo_acquisition_class,
120+
get_sorted_choice_parameter,
120121
get_sum_constraint1,
121122
get_sum_constraint2,
122123
get_surrogate,
@@ -172,6 +173,7 @@
172173
("BraninMetric", get_branin_metric),
173174
("ChainedInputTransform", get_chained_input_transform),
174175
("ChoiceParameter", get_choice_parameter),
176+
("ChoiceParameter", get_sorted_choice_parameter),
175177
# testing with non-default argument
176178
("DataLoaderConfig", partial(DataLoaderConfig, fit_out_of_design=True)),
177179
("Experiment", get_experiment_with_batch_and_single_trial),

ax/utils/testing/core_stubs.py

+10
Original file line numberDiff line numberDiff line change
@@ -1488,6 +1488,16 @@ def get_ordered_choice_parameter() -> ChoiceParameter:
14881488
)
14891489

14901490

1491+
def get_sorted_choice_parameter() -> ChoiceParameter:
1492+
return ChoiceParameter(
1493+
name="y",
1494+
parameter_type=ParameterType.STRING,
1495+
values=["2", "1", "3"],
1496+
is_ordered=True,
1497+
sort_values=True,
1498+
)
1499+
1500+
14911501
def get_task_choice_parameter() -> ChoiceParameter:
14921502
return ChoiceParameter(
14931503
name="y",

0 commit comments

Comments
 (0)