Skip to content

Commit 3a77c2b

Browse files
blethammeta-codesync[bot]
authored andcommitted
Option to not include range bounds in SearchSpace.check_membership (facebook#4962)
Summary: Pull Request resolved: facebook#4962 Reviewed By: saitcakmak Differential Revision: D94707497 fbshipit-source-id: 9f827e607939a282269d4793e51eb62cb3b282f9
1 parent 327b230 commit 3a77c2b

File tree

3 files changed

+57
-6
lines changed

3 files changed

+57
-6
lines changed

ax/analysis/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ def _prepare_modeled_arm_data(
377377
parameterization=arm.parameters,
378378
raise_error=False,
379379
check_all_parameters_present=True,
380+
check_range_bounds=False,
380381
):
381382
predictable_pairs.append((trial_index, arm))
382383
else:

ax/core/search_space.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ def check_membership(
341341
parameterization: Mapping[str, TParamValue],
342342
raise_error: bool = False,
343343
check_all_parameters_present: bool = True,
344+
check_range_bounds: bool = True,
344345
) -> bool:
345346
"""Whether the given parameterization belongs in the search space.
346347
@@ -354,6 +355,10 @@ def check_membership(
354355
with detailed explanation of why.
355356
check_all_parameters_present: Ensure that parameterization specifies
356357
values for all parameters as expected by the search space.
358+
check_range_bounds: If False, only check that values for
359+
RangeParameters have the correct type, without enforcing
360+
the parameter bounds. Other parameter types (ChoiceParameter,
361+
FixedParameter, DerivedParameter) are still fully validated.
357362
358363
Returns:
359364
Whether the parameterization is contained in the search space.
@@ -366,12 +371,16 @@ def check_membership(
366371

367372
for name, value in parameterization.items():
368373
p = self.parameters[name]
369-
kwargs = (
370-
{"parameters": parameterization}
371-
if isinstance(p, DerivedParameter)
372-
else {}
373-
)
374-
if not p.validate(value=value, raises=False, **kwargs):
374+
if not check_range_bounds and isinstance(p, RangeParameter):
375+
is_valid = p.is_valid_type(value)
376+
else:
377+
kwargs = (
378+
{"parameters": parameterization}
379+
if isinstance(p, DerivedParameter)
380+
else {}
381+
)
382+
is_valid = p.validate(value=value, raises=False, **kwargs)
383+
if not is_valid:
375384
if raise_error:
376385
raise ValueError(
377386
f"{value} is not a valid value for "

ax/core/tests/test_search_space.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,47 @@ def test_CheckMembership(self) -> None:
303303
with self.assertRaises(ValueError):
304304
self.ss2.check_membership(p_dict, raise_error=True)
305305

306+
def test_CheckMembershipSkipRangeBounds(self) -> None:
307+
ss = SearchSpace(
308+
parameters=[
309+
RangeParameter(
310+
name="x",
311+
parameter_type=ParameterType.FLOAT,
312+
lower=0.0,
313+
upper=1.0,
314+
),
315+
ChoiceParameter(
316+
name="c",
317+
parameter_type=ParameterType.STRING,
318+
values=["a", "b"],
319+
),
320+
]
321+
)
322+
323+
# Value out of range: default check rejects it
324+
p_dict = {"x": 5.0, "c": "a"}
325+
self.assertFalse(ss.check_membership(p_dict))
326+
self.assertFalse(ss.check_membership(p_dict, check_range_bounds=True))
327+
328+
# With check_range_bounds=False, out-of-range value is accepted
329+
self.assertTrue(ss.check_membership(p_dict, check_range_bounds=False))
330+
331+
# Wrong type for range parameter is still rejected
332+
p_dict_wrong_type = {"x": "not_a_number", "c": "a"}
333+
self.assertFalse(
334+
ss.check_membership(p_dict_wrong_type, check_range_bounds=False)
335+
)
336+
with self.assertRaises(ValueError):
337+
ss.check_membership(
338+
p_dict_wrong_type, check_range_bounds=False, raise_error=True
339+
)
340+
341+
# Invalid choice value is still rejected
342+
p_dict_bad_choice = {"x": 5.0, "c": "invalid"}
343+
self.assertFalse(
344+
ss.check_membership(p_dict_bad_choice, check_range_bounds=False)
345+
)
346+
306347
def test_check_membership_df(self) -> None:
307348
"""Test vectorized membership check on DataFrames."""
308349
# Create test DataFrame with valid and invalid rows

0 commit comments

Comments
 (0)