Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ax/analysis/healthcheck/search_space_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ def search_space_boundary_proportions(
if isinstance(parameter, RangeParameter):
lower = parameter.lower
upper = parameter.upper
elif isinstance(parameter, ChoiceParameter) and parameter.is_ordered:
elif (
isinstance(parameter, ChoiceParameter)
and parameter.is_ordered
and all(isinstance(v, (int, float)) for v in parameter.values)
):
values = [
assert_is_instance(v, Union[int, float]) for v in parameter.values
]
Expand Down
32 changes: 32 additions & 0 deletions ax/analysis/healthcheck/tests/test_search_space_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,35 @@ def test_search_space_boundary_proportions(self) -> None:
)
)
)

def test_search_space_boundary_proportions_string_ordered_choice(self) -> None:
"""Test that ordered choice parameters with string values are skipped."""
ss = SearchSpace(
parameters=[
RangeParameter(
name="float_range",
parameter_type=ParameterType.FLOAT,
lower=1.0,
upper=6.0,
),
ChoiceParameter(
name="string_ordered_choice",
parameter_type=ParameterType.STRING,
values=["option_a", "option_b", "option_c"],
is_ordered=True,
),
],
)

parameterizations: list[dict[str, None | bool | float | int | str]] = [
{"float_range": 1.0, "string_ordered_choice": "option_a"},
{"float_range": 3.0, "string_ordered_choice": "option_b"},
]

# Should not raise -- string ordered choice should be skipped
df = search_space_boundary_proportions(
search_space=ss, parameterizations=parameterizations
)
# Only float_range boundaries should be present (lower and upper)
self.assertEqual(len(df), 2)
self.assertTrue(all("float_range" in b for b in df["Boundary"].values))