Skip to content

Commit e1a2ad4

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
(1/5) Port helpers to OSS for the new Complexity Rating Healthcheck - is_unordered_choice, can_map_to_binary (facebook#4645)
Summary: Pull Request resolved: facebook#4645 Differential Revision: D88883178
1 parent f2e1c70 commit e1a2ad4

2 files changed

Lines changed: 159 additions & 1 deletion

File tree

ax/adapter/adapter_utils.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
OutcomeConstraint,
3636
ScalarizedOutcomeConstraint,
3737
)
38-
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
38+
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
3939
from ax.core.parameter_constraint import ParameterConstraint
4040
from ax.core.search_space import SearchSpace, SearchSpaceDigest
4141
from ax.core.types import TBounds, TCandidateMetadata
@@ -1321,3 +1321,56 @@ def _consolidate_comparisons(X: Tensor, Y: Tensor) -> tuple[Tensor, Tensor]:
13211321

13221322
X, Y, _ = consolidate_duplicates(X, Y)
13231323
return X, Y
1324+
1325+
1326+
def is_unordered_choice(
1327+
p: Parameter, min_choices: int | None = None, max_choices: int | None = None
1328+
) -> bool:
1329+
"""Returns whether a parameter is an unordered choice (categorical) parameter.
1330+
1331+
You can also specify `min_choices` and `max_choices` to restrict how many
1332+
possible values the parameter can take on.
1333+
1334+
Args:
1335+
p: Parameter.
1336+
min_choices: The minimum number of possible values for the parameter.
1337+
max_choices: The maximum number of possible values for the parameter.
1338+
1339+
Returns:
1340+
A boolean indicating whether p is an unordered choice parameter or not.
1341+
"""
1342+
if min_choices is not None and min_choices < 0:
1343+
raise UserInputError("`min_choices` must be a non-negative integer.")
1344+
if max_choices is not None and max_choices < 0:
1345+
raise UserInputError("`max_choices` must be a non-negative integer.")
1346+
if (
1347+
min_choices is not None
1348+
and max_choices is not None
1349+
and min_choices > max_choices
1350+
):
1351+
raise UserInputError("`min_choices` cannot be larger than than `max_choices`.")
1352+
return (
1353+
isinstance(p, ChoiceParameter)
1354+
and not p.is_ordered
1355+
and (min_choices is None or min_choices <= len(p.values))
1356+
and (max_choices is None or max_choices >= len(p.values))
1357+
)
1358+
1359+
1360+
def can_map_to_binary(p: Parameter) -> bool:
1361+
"""Returns whether a parameter can be transformed to a binary parameter.
1362+
1363+
Any choice/range parameters with exactly two values can be transformed to a
1364+
binary parameter.
1365+
1366+
Args:
1367+
p: Parameter.
1368+
1369+
Returns
1370+
A boolean indicating whether p can be transformed to a binary parameter.
1371+
"""
1372+
return (isinstance(p, ChoiceParameter) and len(p.values) == 2) or (
1373+
isinstance(p, RangeParameter)
1374+
and p.parameter_type == ParameterType.INT
1375+
and p.lower == p.upper - 1
1376+
)

ax/adapter/tests/test_adapter_utils.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
from ax.adapter.adapter_utils import (
1313
_get_adapter_training_data,
1414
arm_to_np_array,
15+
can_map_to_binary,
1516
extract_search_space_digest,
1617
feasible_hypervolume,
18+
is_unordered_choice,
1719
process_contextual_datasets,
1820
transform_search_space,
1921
validate_and_apply_final_transform,
@@ -414,3 +416,106 @@ def test_validate_and_apply_final_transform_none_target_point(self) -> None:
414416

415417
# Assert: confirm target point remains None
416418
self.assertIsNone(target_p)
419+
420+
def test_is_unordered_choice(self) -> None:
421+
# Test cases where is_unordered_choice should return True
422+
# (with min_choices=3, max_choices=5)
423+
for p in [
424+
ChoiceParameter("p", ParameterType.INT, values=[0, 1, 2], is_ordered=False),
425+
ChoiceParameter(
426+
"p", ParameterType.INT, values=[0, 1, 2, 4, 5], is_ordered=False
427+
),
428+
ChoiceParameter(
429+
"p", ParameterType.STRING, values=["a", "b", "c", "d"], is_ordered=False
430+
),
431+
]:
432+
with self.subTest(p=p):
433+
self.assertTrue(is_unordered_choice(p, min_choices=3, max_choices=5))
434+
435+
# Test cases where is_unordered_choice should return False
436+
# (with min_choices=3, max_choices=5)
437+
for p in [
438+
# Too few choices
439+
ChoiceParameter("p", ParameterType.INT, values=[0, 1], is_ordered=False),
440+
# Ordered choice (INT)
441+
ChoiceParameter(
442+
"p", ParameterType.INT, values=[0, 1, 2, 4], is_ordered=True
443+
),
444+
# Range parameter (not a choice)
445+
RangeParameter("p", parameter_type=ParameterType.INT, lower=0, upper=3),
446+
# Ordered choice (STRING)
447+
ChoiceParameter(
448+
"p", ParameterType.STRING, values=["0", "1", "2"], is_ordered=True
449+
),
450+
]:
451+
with self.subTest(p=p):
452+
self.assertFalse(is_unordered_choice(p, min_choices=3, max_choices=5))
453+
454+
# Test error cases
455+
p = ChoiceParameter("p", ParameterType.INT, values=[0, 1, 2], is_ordered=False)
456+
with self.assertRaisesRegex(
457+
UserInputError, "`min_choices` must be a non-negative integer."
458+
):
459+
is_unordered_choice(p, min_choices=-3)
460+
with self.assertRaisesRegex(
461+
UserInputError, "`max_choices` must be a non-negative integer."
462+
):
463+
is_unordered_choice(p, max_choices=-1)
464+
with self.assertRaisesRegex(
465+
UserInputError, "`min_choices` cannot be larger than than `max_choices`."
466+
):
467+
is_unordered_choice(p, min_choices=3, max_choices=2)
468+
469+
def test_can_map_to_binary(self) -> None:
470+
# Test cases where can_map_to_binary should return True
471+
for p in [
472+
# Int range with exactly 2 values
473+
RangeParameter(
474+
name="p", parameter_type=ParameterType.INT, lower=0, upper=1
475+
),
476+
RangeParameter(
477+
name="p", parameter_type=ParameterType.INT, lower=3, upper=4
478+
),
479+
# Choice with exactly 2 values
480+
ChoiceParameter(
481+
name="p",
482+
parameter_type=ParameterType.INT,
483+
values=[0, 1],
484+
is_ordered=False,
485+
),
486+
ChoiceParameter(
487+
name="p",
488+
parameter_type=ParameterType.STRING,
489+
values=["a", "b"],
490+
is_ordered=False,
491+
),
492+
]:
493+
with self.subTest(p=p):
494+
self.assertTrue(can_map_to_binary(p))
495+
496+
# Test cases where can_map_to_binary should return False
497+
for p in [
498+
# Float range (continuous, not binary)
499+
RangeParameter(
500+
name="p", parameter_type=ParameterType.FLOAT, lower=0, upper=1
501+
),
502+
# Int range with more than 2 values
503+
RangeParameter(
504+
name="p", parameter_type=ParameterType.INT, lower=0, upper=3
505+
),
506+
# Choice with more than 2 values
507+
ChoiceParameter(
508+
name="p",
509+
parameter_type=ParameterType.INT,
510+
values=[0, 1, 2],
511+
is_ordered=False,
512+
),
513+
ChoiceParameter(
514+
name="p",
515+
parameter_type=ParameterType.STRING,
516+
values=["a", "b", "c"],
517+
is_ordered=False,
518+
),
519+
]:
520+
with self.subTest(p=p):
521+
self.assertFalse(can_map_to_binary(p))

0 commit comments

Comments
 (0)